{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7yuytuIllsv1"
      },
      "source": [
        "# Trax Quick Intro\n",
        "\n",
        "We train a **Trax Transformer** (or Reformer) on a simple copy problem and run inference with it.\n",
        "* See how to create your inputs from Python.\n",
        "* Learn how to run the trainer.\n",
        "* Run fast inference with Transformer.\n",
        "\n",
	"\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "BIl27504La0G"
      },
      "source": [
        "## General Setup\n",
        "Execute the following few cells (once) before running any of the code samples in this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "oILRLCWN_16u"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "# Copyright 2020 Google LLC.\n",
        "\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "\n",
        "import os\n",
        "import numpy as np\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "vlGjGoGMTt-D"
      },
      "outputs": [],
      "source": [
        "#@title\n",
        "# Import Trax\n",
        "\n",
        "! pip install -q -U trax\n",
        "! pip install -q tensorflow\n",
        "\n",
        "import trax"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-LQ89rFFsEdk"
      },
      "source": [
        "# Transformer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "height": 68
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 318,
          "status": "ok",
          "timestamp": 1578963024402,
          "user": {
            "displayName": "Lukasz Kaiser",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mC8pChl87HbK_eOtVhtNPwUVx8btvfyYzH9UHn3=s64",
            "userId": "13267693649565518272"
          },
          "user_tz": 480
        },
        "id": "djTiSLcaNFGa",
        "outputId": "610b5f32-e47e-4afd-971f-3e33b03adc0f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Inputs[0]:  [ 0  6 13 29 22  0  6 13 29 22]\n",
            "Targets[0]: [ 0  6 13 29 22  0  6 13 29 22]\n",
            "Mask[0]:    [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]\n"
          ]
        }
      ],
      "source": [
        "# Construct inputs, see one batch\n",
        "def copy_task(batch_size, vocab_size, length):\n",
        "  \"\"\"Yields an endless stream of copy-task batches; each row is 0w0w.\n",
        "\n",
        "  Every example is a 0, a random word w, another 0 and then w again, so a\n",
        "  language model that has seen the first half can reproduce the second.\n",
        "\n",
        "  Args:\n",
        "    batch_size: number of examples (rows) in each batch.\n",
        "    vocab_size: vocabulary size; the symbols of w are drawn from\n",
        "      1 .. vocab_size - 2 because the upper bound of randint is exclusive.\n",
        "    length: total length of one example; must be even.\n",
        "\n",
        "  Yields:\n",
        "    Tuples (inputs, targets, loss_weights), each of shape\n",
        "    (batch_size, length). Inputs and targets are the same array. The\n",
        "    weights are 0 on the first half (leading 0 and the random word w) and\n",
        "    1 on the second half (separator 0 and the copy of w).\n",
        "  \"\"\"\n",
        "  # These do not change between batches, so compute them once.\n",
        "  assert length % 2 == 0\n",
        "  w_length = (length // 2) - 1\n",
        "  zero = np.zeros([batch_size, 1], np.int32)\n",
        "  while True:\n",
        "    w = np.random.randint(low=1, high=vocab_size-1,\n",
        "                          size=(batch_size, w_length))\n",
        "    # Only the second half is predictable, so only it carries loss weight.\n",
        "    # Weighting the last symbol of the first w (as w_length / w_length+2\n",
        "    # did) would score the model on a random, unpredictable token.\n",
        "    loss_weights = np.concatenate([np.zeros((batch_size, w_length+1)),\n",
        "                                   np.ones((batch_size, w_length+1))], axis=1)\n",
        "    x = np.concatenate([zero, w, zero, w], axis=1)\n",
        "    yield (x, x, loss_weights)  # Here inputs and targets are the same.\n",
        "# Wrap the generator for the trainer with batch_size=16, vocab_size=32 and\n",
        "# length=10. The lambda ignores its single argument (presumably the number\n",
        "# of devices; verify against the trax.supervised.Inputs API).\n",
        "copy_inputs = trax.supervised.Inputs(lambda _: copy_task(16, 32, 10))\n",
        "\n",
        "# Peek into the inputs: pull one batch from the training stream and print\n",
        "# the first example of each of its three arrays.\n",
        "data_stream = copy_inputs.train_stream(1)\n",
        "inputs, targets, mask = next(data_stream)\n",
        "print(\"Inputs[0]:  %s\" % str(inputs[0]))\n",
        "print(\"Targets[0]: %s\" % str(targets[0]))\n",
        "print(\"Mask[0]:    %s\" % str(mask[0]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "height": 629
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 28950,
          "status": "ok",
          "timestamp": 1578963053368,
          "user": {
            "displayName": "Lukasz Kaiser",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mC8pChl87HbK_eOtVhtNPwUVx8btvfyYzH9UHn3=s64",
            "userId": "13267693649565518272"
          },
          "user_tz": 480
        },
        "id": "kSauPt0NUl_o",
        "outputId": "e5cf66a1-5deb-43f3-fb0a-a5fa479c7cbf"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Step    500: Ran 500 train steps in 16.51 secs\n",
            "Step    500: Evaluation\n",
            "Step    500: train                   accuracy |  0.53125000\n",
            "Step    500: train                       loss |  1.83887446\n",
            "Step    500: train         neg_log_perplexity | -1.83887446\n",
            "Step    500: train weights_per_batch_per_core |  80.00000000\n",
            "Step    500: eval                    accuracy |  0.52500004\n",
            "Step    500: eval                        loss |  1.92791247\n",
            "Step    500: eval          neg_log_perplexity | -1.92791247\n",
            "Step    500: eval  weights_per_batch_per_core |  80.00000000\n",
            "Step    500: Finished evaluation\n",
            "\n",
            "Step   1000: Ran 500 train steps in 2.54 secs\n",
            "Step   1000: Evaluation\n",
            "Step   1000: train                   accuracy |  1.00000000\n",
            "Step   1000: train                       loss |  0.00707983\n",
            "Step   1000: train         neg_log_perplexity | -0.00707983\n",
            "Step   1000: train weights_per_batch_per_core |  80.00000000\n",
            "Step   1000: eval                    accuracy |  1.00000000\n",
            "Step   1000: eval                        loss |  0.01029818\n",
            "Step   1000: eval          neg_log_perplexity | -0.01029818\n",
            "Step   1000: eval  weights_per_batch_per_core |  80.00000000\n",
            "Step   1000: Finished evaluation\n",
            "\n",
            "Step   1500: Ran 500 train steps in 2.46 secs\n",
            "Step   1500: Evaluation\n",
            "Step   1500: train                   accuracy |  1.00000000\n",
            "Step   1500: train                       loss |  0.00037777\n",
            "Step   1500: train         neg_log_perplexity | -0.00037777\n",
            "Step   1500: train weights_per_batch_per_core |  80.00000000\n",
            "Step   1500: eval                    accuracy |  1.00000000\n",
            "Step   1500: eval                        loss |  0.00037660\n",
            "Step   1500: eval          neg_log_perplexity | -0.00037660\n",
            "Step   1500: eval  weights_per_batch_per_core |  80.00000000\n",
            "Step   1500: Finished evaluation\n"
          ]
        }
      ],
      "source": [
        "# Transformer LM\n",
        "def tiny_transformer_lm(mode):\n",
        "  \"\"\"Returns a small 2-layer TransformerLM built for the given mode.\"\"\"\n",
        "  # trax.models.ReformerLM can be tried here as well.\n",
        "  hparams = dict(vocab_size=32, d_model=32, d_ff=128, n_layers=2)\n",
        "  model = trax.models.TransformerLM(mode=mode, **hparams)\n",
        "  return model\n",
        "\n",
        "# Train tiny model with Trainer.\n",
        "# Training output is kept under ~/train_dir; the shell line below deletes a\n",
        "# model.pkl left there by an earlier run.\n",
        "output_dir = os.path.expanduser('~/train_dir/')\n",
        "!rm -f ~/train_dir/model.pkl  # Remove old model.\n",
        "# The model is passed as the factory function itself (it takes `mode`), and\n",
        "# the loss, optimizer and lr schedule are likewise passed uncalled.\n",
        "trainer = trax.supervised.Trainer(\n",
        "    model=tiny_transformer_lm,\n",
        "    loss_fn=trax.layers.CrossEntropyLoss,\n",
        "    optimizer=trax.optimizers.Adafactor,  # Change optimizer params here.\n",
        "    lr_schedule=trax.lr.MultifactorSchedule,  # Change lr schedule here.\n",
        "    inputs=copy_inputs,\n",
        "    output_dir=output_dir,\n",
        "    has_weights=True)  # Because we have loss mask, this API may change.\n",
        "\n",
        "# Train for 3 epochs each consisting of 500 train batches, eval on 2 batches.\n",
        "# That is n_epochs * train_steps = 1500 training steps in total.\n",
        "n_epochs  = 3\n",
        "train_steps = 500\n",
        "eval_steps = 2\n",
        "for _ in range(n_epochs):\n",
        "  trainer.train_epoch(train_steps, eval_steps)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "height": 34
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 1190,
          "status": "ok",
          "timestamp": 1578963686769,
          "user": {
            "displayName": "Lukasz Kaiser",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mC8pChl87HbK_eOtVhtNPwUVx8btvfyYzH9UHn3=s64",
            "userId": "13267693649565518272"
          },
          "user_tz": 480
        },
        "id": "cqjYoxPEu8PG",
        "outputId": "d5a99b97-843b-4761-a711-ec337098e30e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]\n"
          ]
        }
      ],
      "source": [
        "# Initialize model for inference.\n",
        "predict_model = tiny_transformer_lm(mode='predict')\n",
        "# The model is initialized for int32 inputs of shape (1, 1), which matches\n",
        "# the single-token arrays fed in the loop below.\n",
        "predict_signature = trax.shapes.ShapeDtype((1,1), dtype=np.int32)\n",
        "predict_model.init(predict_signature)\n",
        "# Load the weights from the model.pkl file in the training output dir.\n",
        "predict_model.init_from_file(os.path.join(output_dir, \"model.pkl\"),\n",
        "                             weights_only=True)\n",
        "# You can also do: predict_model.weights = trainer.model_weights\n",
        "\n",
        "# Run inference\n",
        "prefix = [0, 1, 2, 3, 4, 0]   # Change non-0 digits to see if it's copying\n",
        "cur_input = np.array([[0]])\n",
        "result = []\n",
        "# Feed one token per step and take the argmax of the logits as the next\n",
        "# token. For the first len(prefix) - 1 steps that prediction is replaced by\n",
        "# the prefix token, so the last prefix element is never forced: from step\n",
        "# len(prefix) - 1 onwards the result holds the model's own predictions.\n",
        "# The model is still called on the forced steps (presumably to advance its\n",
        "# internal decoding state; verify against the trax predict-mode docs).\n",
        "for i in range(10):\n",
        "  logits = predict_model(cur_input)\n",
        "  next_input = np.argmax(logits[0, 0, :], axis=-1)\n",
        "  if i \u003c len(prefix) - 1:\n",
        "    next_input = prefix[i]\n",
        "  cur_input = np.array([[next_input]])\n",
        "  result.append(int(next_input))  # Append to the result\n",
        "print(result)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "Trax Quick Intro",
      "provenance": [
        {
          "file_id": "1v1GvTkEFjMH_1c-bdS7JzNS70u9RUEHV",
          "timestamp": 1578964243645
        },
        {
          "file_id": "1SplqILjJr_ZqXcIUkNIk0tSbthfhYm07",
          "timestamp": 1572044421118
        },
        {
          "file_id": "intro.ipynb",
          "timestamp": 1571858674399
        },
        {
          "file_id": "1sF8QbqJ19ZU6oy5z4GUTt4lgUCjqO6kt",
          "timestamp": 1569980697572
        },
        {
          "file_id": "1EH76AWQ_pvT4i8ZXfkv-SCV4MrmllEl5",
          "timestamp": 1563927451951
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
